{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a2e55bff",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/agent/openai_forced_function_call.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56e895e8",
   "metadata": {},
   "source": [
    "# OpenAI agent: specifying a forced function call"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "17143246",
   "metadata": {},
   "source": [
    "If you're opening this Notebook on colab, you will probably need to install LlamaIndex 🦙."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de42f955",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install llama-index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3789f9b2-3095-4257-b1bf-874456f8d560",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from typing import Sequence, List\n",
    "\n",
    "from llama_index.llms import OpenAI, ChatMessage\n",
    "from llama_index.tools import BaseTool, FunctionTool\n",
    "from llama_index.agent import OpenAIAgent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73820afc-1e00-4f36-957c-9e59572016d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add(a: int, b: int) -> int:\n",
    "    \"\"\"Add two integers and returns the result integer\"\"\"\n",
    "    return a + b\n",
    "\n",
    "\n",
    "add_tool = FunctionTool.from_defaults(fn=add)\n",
    "\n",
    "\n",
    "def useless_tool() -> int:\n",
    "    \"\"\"This is a uselss tool.\"\"\"\n",
    "    return \"This is a uselss output.\"\n",
    "\n",
    "\n",
    "useless_tool = FunctionTool.from_defaults(fn=useless_tool)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06084e32-4895-4c97-9855-354b293962f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = OpenAI(model=\"gpt-3.5-turbo-0613\")\n",
    "agent = OpenAIAgent.from_tools([useless_tool, add_tool], llm=llm, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34b44bc7-59b9-40ce-9b19-eb437adf6788",
   "metadata": {},
   "source": [
    "### \"Auto\" function call"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7a56538-fbf0-49a8-89dc-b8813ee8653e",
   "metadata": {},
   "source": [
    "The agent automatically selects the useful \"add\" tool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "749b26c8-8377-4ffe-bff9-89e913fde619",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "STARTING TURN 1\n",
      "---------------\n",
      "\n",
      "=== Calling Function ===\n",
      "Calling function: add with args: {\n",
      "  \"a\": 5,\n",
      "  \"b\": 2\n",
      "}\n",
      "Got output: 7\n",
      "========================\n",
      "\n",
      "STARTING TURN 2\n",
      "---------------\n",
      "\n"
     ]
    }
   ],
   "source": [
    "response = agent.chat(\n",
    "    \"What is 5 + 2?\", tool_choice=\"auto\"\n",
    ")  # note function_call param is deprecated\n",
    "# use tool_choice instead"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e81e08c2-26ce-4234-b184-2b1922b1715d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The sum of 5 and 2 is 7.\n"
     ]
    }
   ],
   "source": [
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4855b6f2-8b68-4fae-9fe2-0ed8a7cb193b",
   "metadata": {},
   "source": [
    "### Forced function call"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a6d06ea-b744-426c-82a6-f67a74581474",
   "metadata": {},
   "source": [
    "The agent is forced to call the \"useless_tool\" before selecting the \"add\" tool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2884742-0062-4413-8dce-8fcc43971a1d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "STARTING TURN 1\n",
      "---------------\n",
      "\n",
      "=== Calling Function ===\n",
      "Calling function: useless_tool with args: {}\n",
      "Got output: This is a uselss output.\n",
      "========================\n",
      "\n",
      "STARTING TURN 2\n",
      "---------------\n",
      "\n",
      "=== Calling Function ===\n",
      "Calling function: add with args: {\n",
      "  \"a\": 5,\n",
      "  \"b\": 2\n",
      "}\n",
      "Got output: 7\n",
      "========================\n",
      "\n",
      "STARTING TURN 3\n",
      "---------------\n",
      "\n"
     ]
    }
   ],
   "source": [
    "response = agent.chat(\"What is 5 * 2?\", tool_choice=\"useless_tool\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67b7f4dd-1284-45d5-a8c3-6f5482f54c00",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The product of 5 and 2 is 10.\n"
     ]
    }
   ],
   "source": [
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2c83bd3-2d83-48b7-b925-7837ac941719",
   "metadata": {},
   "source": [
    "### \"None\" function call"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ced85bc0-cbe7-484e-b3e6-8a326cb1c4d8",
   "metadata": {},
   "source": [
    "The agent is forced to not use a tool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "935b9b4b-1ba2-4fdd-ab9d-07cd9bf081bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "STARTING TURN 1\n",
      "---------------\n",
      "\n"
     ]
    }
   ],
   "source": [
    "response = agent.chat(\"What is 5 * 2?\", tool_choice=\"none\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c63b496-3f08-4240-a415-4eca7f331847",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The product of 5 and 2 is 10.\n"
     ]
    }
   ],
   "source": [
    "print(response)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
